[MUSA][9/N] Add FA3 attention backend support through MATE (MUSA AI Tensor Engine)#22051
Conversation
Code Review
This pull request introduces support for the MUSA (Moore Threads GPU) hardware backend, specifically focusing on Flash Attention integration. It adds necessary dependencies, configuration parameters, and a new MUSA-specific attention module that wraps the mate library's flash attention functions. The implementation uses a thread-local context manager to automatically inject scheduler metadata into attention calls. Key changes include updates to the attention registry, the FlashAttentionBackend to handle MUSA-specific logic, and server argument adjustments for MUSA compatibility. Feedback highlights potential issues with global buffer safety in multi-GPU environments, metadata cache collisions due to non-unique keys, and the implications of ignoring cu_seqlens_k_new in the MUSA implementation.
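The thread-local context-manager pattern mentioned above can be sketched as follows. This is a hypothetical illustration, not the actual MUSA code: the names `_local`, `attention_metadata_ctx`, `current_metadata`, and `attention_call` are invented for this sketch, which only shows how thread-local storage lets wrapped attention calls pick up scheduler metadata implicitly.

```python
import threading
from contextlib import contextmanager

# Thread-local slot for the injected metadata; each thread sees its own value.
_local = threading.local()

@contextmanager
def attention_metadata_ctx(metadata):
    # Stash metadata for the duration of the block, restoring the previous
    # value on exit (even if an exception propagates).
    prev = getattr(_local, "metadata", None)
    _local.metadata = metadata
    try:
        yield
    finally:
        _local.metadata = prev

def current_metadata():
    return getattr(_local, "metadata", None)

def attention_call(q):
    # A wrapped kernel call would read the injected metadata here instead of
    # taking it as an explicit argument.
    return (q, current_metadata())

with attention_metadata_ctx({"cu_seqlens_q": [0, 4, 9]}):
    out = attention_call("q_tensor")
```

Because the slot is thread-local, concurrent schedulers on different threads do not clobber each other's metadata, which is the usual motivation for this pattern over a plain module-level global.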
yeahdongcn left a comment:
I think it would be better to split this into two commits: one carrying over changes from the previous PR, and another fixing the regression in selecting FA kernels for different NVIDIA GPU architectures. This should make it easier for the SGLang core team to review.
/tag-and-rerun-ci
/rerun-failed-ci
Hi @Fridge003 and @Kangyan-Zhou, all NVIDIA CI checks have passed. Could you please take a look if we can merge this? Thanks!
…ensor Engine) (#22051) Co-authored-by: zhiguo.qin <zhiguo.qin@mthreads.com>
Motivation
This PR fixes the Flash Attention backend support that was previously merged in PR #17985 but later reverted in PR #22002 due to a bug. The original commit 2373552 caused CI failures (see failed CI job).
Previously, the MUSA-adapted flash attention implementation had a bug in the `_forward_extend_impl` method. The code was missing a proper mechanism to select the correct kernel implementation based on the `fa_impl_ver` parameter, causing it to always use the default FA3 implementation regardless of the specified version.

Fix Applied

After rebasing to the latest main branch, the kernel selection logic has been refactored and moved to the `FlashAttentionBackend.__init__` method. This ensures that the appropriate flash attention implementation is selected during initialization based on the `fa_impl_ver` parameter.

Moved kernel selection to `__init__`: The logic to select the correct flash attention kernel (including MUSA-specific implementations) is now handled in the `FlashAttentionBackend.__init__` method, where two instance variables are initialized:
- `self.flash_attn_with_kvcache`: for cached attention operations
- `self.flash_attn_varlen_func`: for variable-length attention operations

Updated forward methods: Both `_forward_extend_impl` and `_forward_decode_impl` now use these instance variables instead of directly calling the default implementations, ensuring the correct kernel is used based on the initialized configuration.
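The select-once-in-`__init__`, dispatch-through-attributes design can be sketched as below. This is a minimal sketch, not the actual SGLang implementation: the kernel functions and the `registry` dict are placeholders standing in for the real FA3 and MUSA/mate kernels.

```python
# Placeholder kernels; the real ones would be flash-attn / mate functions.
def fa3_varlen(*args, **kwargs):
    return "fa3_varlen"

def musa_varlen(*args, **kwargs):
    return "musa_varlen"

class FlashAttentionBackend:
    def __init__(self, fa_impl_ver="fa3"):
        # Select the kernel once at construction time, keyed on
        # fa_impl_ver, instead of re-deciding (or forgetting to decide)
        # inside each forward method.
        registry = {"fa3": fa3_varlen, "musa": musa_varlen}
        self.flash_attn_varlen_func = registry[fa_impl_ver]
        self.flash_attn_with_kvcache = registry[fa_impl_ver]

    def _forward_extend_impl(self, q):
        # Dispatch through the instance attribute chosen in __init__,
        # so the configured implementation is always used.
        return self.flash_attn_varlen_func(q)

backend = FlashAttentionBackend(fa_impl_ver="musa")
result = backend._forward_extend_impl("q")
```

Binding the kernel to an instance attribute also fixes the original bug mechanically: a forward path that only ever calls `self.flash_attn_varlen_func` cannot accidentally fall back to the default FA3 implementation.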
Speed Tests and Profiling
Checklist
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci

Related Links: